import PIL
from cv2 import split
from matplotlib import image
from numpy import mask_indices
from image_angle_correct import *
from hsv_test import hsv_split, show,hsv_split_pre
from model import *
import sys
sys.path.append("../")
from util.keys_prov import *
from PIL import Image, ImageDraw, ImageFont

# Training-set image path
# source_image_path = "/home/meroke/code/detection_2022/ocr/image/test/"  # test
source_image_path = "/home/meroke/code/detection_2022/ocr/image/test/"
# source_image_path = "/home/meroke/code/detection/train_images/" # main
# Save path for images annotated with the recognition result
# result_image_save_path = "/home/meroke/code/detection/result/result/" # main
result_image_save_path = "/home/meroke/code/detection_2022/ocr/image/result/" # test
# Save path for images whose region could not be recognized
# error_image_save_path ="/home/meroke/code/detection/result/error/" # main
error_image_save_path ="/home/meroke/code/detection_2022/ocr/image/error/" # test 

# Save path for region (province) recognition results
word_detect_path = "/home/meroke/code/detection_2022/ocr/log/result.txt"
# Save path for detailed recognition output
detect_detail_path = "/home/meroke/code/detection_2022/ocr/log/detect_detail.txt"

# Counter: number of images whose province was detected
detect_count =0
# Counter: total number of images processed
image_count = 0

def cv2AddChineseText(img, text, position, textColor=(0, 255, 0), textSize=30):
    """Draw Chinese text on an image.

    param img: image to draw on (OpenCV BGR ndarray or PIL Image)
    param text: text to draw
    param position: (x, y) pixel position of the text
    param textColor: text color (RGB tuple)
    param textSize: font size
    return: BGR-format ndarray with the text drawn on it
    """
    if isinstance(img, np.ndarray):  # OpenCV image? convert BGR -> RGB PIL Image
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # drawing handle for the PIL image
    draw = ImageDraw.Draw(img)
    # Locate the bundled font file relative to the working directory.
    # os.getcwd() replaces the former os.popen("pwd") subprocess call —
    # same value, no shell round-trip.
    path = os.getcwd()
    # NOTE(review): lexicographic '>' is used to detect a cwd *below* the
    # project root; stripping 4 chars presumably removes a "/ocr" suffix —
    # confirm against the deployment layout.
    if path > "/home/tdb2/2022_jsjds/detection_2022":
        path = path[:-4]
    type_path = path + "/ocr/simsun.ttc"
    fontStyle = ImageFont.truetype(type_path, textSize, encoding="utf-8")
    # draw the text
    draw.text(position, text, textColor, font=fontStyle)
    # convert back to OpenCV BGR format
    return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)


def searcch_provience(detected_word):
    """Match the detected strings against the known province table.

    param detected_word: OCR output; detected_word[1] is a list of strings
    return: the matched province name, or "" when nothing matches
    """
    # province maps short key -> full name; the first key (in dict order)
    # matched by any detected string wins
    for key in province:
        for single in detected_word[1]:
            # accept either the full province name or its short key
            if single.find(province[key]) > -1 or single.find(key) > -1:
                # early return replaces the original flag-and-break pattern
                return province[key]
    return ""

def record_result(name, result):
    """Append a recognized province and its image name to the result log.

    param name: image name (without extension) whose province was detected
    param result: province name
    """
    # Bug fix: the old code created a *relative* "log" directory while
    # writing to an absolute path — ensure the file's actual parent exists.
    log_dir = os.path.dirname(word_detect_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    with open(word_detect_path, 'a') as f:
        f.write(result)  # write the province
        f.write(" " + name + ".jpg\r")
def record_detail(name, detail):
    """Append all recognized strings plus the image name to the detail log.

    param name: image name (without extension)
    param detail: list of all recognized strings for that image
    """
    # consistency with record_result: make sure the log directory exists
    log_dir = os.path.dirname(detect_detail_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    with open(detect_detail_path, 'a') as f:
        for item in detail:
            f.write(item + " ")  # one space-separated entry per string
        f.write(" " + name + ".jpg\r")


def error_img_save(name, img):
    """Save an image whose province could not be matched.

    param name: image name (without extension)
    param img: image — either a PIL Image (has .save) or an OpenCV ndarray
    """
    target = os.path.join(error_image_save_path, name + ".jpg")
    try:
        # PIL path: Image objects expose .save
        img.save(target)
    except Exception:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit); fall back to OpenCV for ndarrays.
        cv2.imwrite(target, img)



"""
    image's box_list index:
    0    1
    3    2
"""
def bigger_box(box, bigger_size=10):
    """Order a box's four corners as (tl, tr, br, bl) and grow the box
    outward by *bigger_size* pixels on every side.

    The corner points are modified in place; the reordered list of the
    same point objects is returned.
    """
    # sort the points top-to-bottom (stable, so ties keep input order)
    pts = sorted([box[i] for i in range(len(box))], key=lambda p: p[1])
    upper, lower = pts[:2], pts[2:4]

    # split each pair left/right by x-coordinate
    if upper[0][0] <= upper[1][0]:
        top_left, top_right = upper[0], upper[1]
    else:
        top_left, top_right = upper[1], upper[0]
    if lower[0][0] <= lower[1][0]:
        bottom_left, bottom_right = lower[0], lower[1]
    else:
        bottom_left, bottom_right = lower[1], lower[0]

    ordered = [top_left, top_right, bottom_right, bottom_left]

    # push every corner away from the box centre by bigger_size
    grow = ((-bigger_size, -bigger_size),
            (bigger_size, -bigger_size),
            (bigger_size, bigger_size),
            (-bigger_size, bigger_size))
    for point, (dx, dy) in zip(ordered, grow):
        point[0] += dx
        point[1] += dy

    return ordered


# Contour-area helper, used as a sort key when ordering contours.
def cnt_area(cnt):
    """Return the area of contour *cnt* (delegates to cv2.contourArea)."""
    return cv2.contourArea(cnt)

def word_detect_test(path, name):
    """Detect the text region of an image, correct its rotation, run OCR
    and record the matched province.

    TODO: draw the text box and display the result; dbnet sometimes fails
    to detect a text box at all (missed detections exist).

    param path: image file path
    param name: image file name
    """
    global detect_count, image_count, result_image_save_path, error_image_save_path
    print("{} START".format(name))
    img = cv2.imread(path)
    origin_img = img.copy()
    img = cv2.resize(img, (image_size, image_size))
    img = hsv_split(img)

    # text detection: list of detected box corner points
    box_list, score_list = text_handle.process(img, image_size)
    if len(box_list) < 1:
        print("detect fail {}".format(name))
        error_img_save(name, origin_img)
        return
    # index of the largest detected box
    max_index = max_box_index(box_list, img)

    # draw and save the largest box for inspection
    img = draw_max_bbox(img, box_list[max_index])
    cv2.imwrite(result_image_save_path + "{}_max_bbox.jpg".format(name), img)
    # rotation angle derived from corner points 0 and 1 of the largest box
    angle = get_angle(img, box_list[max_index])

    # angle correction followed by a sharpening convolution
    img = adjust_angle(img, angle)
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], np.float32)  # sharpen kernel
    img = cv2.filter2D(img, -1, kernel=kernel)
    PIL_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    detected_word = OcrHandle().text_predict(PIL_image, image_size)
    result = searcch_provience(detected_word)
    print("result:", result)
    # Bug fix: compare with == instead of identity `is` on a str literal
    # (identity comparison with literals is implementation-dependent and
    # raises SyntaxWarning on modern Python).
    if result == "":
        error_img_save(name, origin_img)
        # retry upside-down: the text may have been rotated 180 degrees
        PIL_image_copy = PIL_image.transpose(Image.ROTATE_180)
        detected_word_copy = OcrHandle().text_predict(PIL_image_copy, image_size)
        result = searcch_provience(detected_word_copy)
        if result != "":
            temp = cv2AddChineseText(PIL_image_copy, result, (50, 100), (255, 0, 0), 50)
            show(name + "unknown", temp)
    else:
        detect_count += 1
    image_count += 1

    print("detected:{}\r".format(result))
    CV_image = cv2AddChineseText(PIL_image, result, (50, 100), (255, 0, 0), 50)

    # record results
    name = name.split('.')[0]
    cv2.imwrite(result_image_save_path + "{}_del.jpg".format(name), CV_image)
    record_result(name, result)
    record_detail(name, detected_word[1])


def resize_image(img, scale=10):
    """Resize an image to the fixed model input size.

    The model cannot handle overly large inputs, and rotating images with
    unequal width/height blurs the edges badly enough to hurt recognition,
    so both dimensions are pinned to image_size instead of being scaled
    adaptively.

    param img: input image (ndarray)
    param scale: kept for backward compatibility; unused with fixed sizing
    return: (resized image, height, width) with height == width == image_size
    """
    # dead pre-computation of the original h/w and the commented-out
    # adaptive-scaling branch were removed — they had no effect
    img = cv2.resize(img, (image_size, image_size))
    return img, image_size, image_size

'''
description:  最终使用识别函数，通过90度步进旋转（必要时再旋转180度），纠正由于文字框误识别导致的错误旋转 
param {*} img 传入处理图像
param {*} name 图片名称
param {*} origin_img  原始图片，展示错误
param {*} short_size 图片短边
return {*}
'''
# def get_result2(img,name,origin_img):
#     global detect_count, image_count, result_image_save_path, error_image_save_path
#     img,h,w =  resize_image(img)
#     short_size = h if h<w else w
#     # print(w,h,short_size)
#     for angle in range(0,361,90):
#         img = adjust_angle(img,angle)
#         #how("rotated",img)
#         PIL_image = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
#         detected_word = OcrHandle().text_predict(PIL_image,short_size)
#         result = searcch_provience(detected_word)
#         # main :log message
#         record_result(name,result)
#         # 记录详细信息
#         record_detail(name,detected_word[1])
#         CV_img = cv2.cvtColor(np.asarray(PIL_image),cv2.COLOR_RGB2BGR)
#         if result != "":
#             # show("result",CV_img,1)
#             # detect_count +=1
#             break
#     if result == "":
#         error_img_save(name,origin_img)
#         error_img_save(name+"mask",img)
#     # image_count += 1
#     return result,detected_word
def get_result2(img, name, origin_img, mode=0):
    """Final recognition routine.

    param img: image to process (resized to image_size internally)
    param name: image name, used for logging and saved files
    param origin_img: original image, saved when recognition fails
    param mode: 0 = detect the text box, rotate by its derived angle and
                retry once at 180 degrees; any other value = brute-force
                rotation in 90-degree steps until a province matches
    return: (province name or "", raw OCR output). On detection failure in
            mode 0 this now returns ("", []) — the old bare `return`
            yielded None and broke callers that unpack two values.
    """
    global detect_count, image_count, result_image_save_path, error_image_save_path
    img, h, w = resize_image(img)
    short_size = h if h < w else w
    result = ""
    if mode == 0:
        # text detection: corner-point list of every detected box
        box_list, score_list = text_handle.process(img, image_size)
        if len(box_list) < 1:
            print("detect fail {}".format(name))
            error_img_save(name, origin_img)
            # bug fix: return an unpackable pair instead of None
            return "", []
        # the largest detected box drives the rotation angle (points 0 and 1)
        max_index = max_box_index(box_list, img)
        angle = get_angle(img, box_list[max_index])
        IF_rotate_180 = 0
        while result == "":
            img = adjust_angle(img, angle)
            PIL_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            # OCR on the rotated image
            detected_word = OcrHandle().text_predict(PIL_image, short_size)
            result = searcch_provience(detected_word)
            if IF_rotate_180:
                break  # already retried upside-down; give up
            if result == "":
                # first attempt failed: retry once rotated a further 180 deg
                angle = 180
                IF_rotate_180 = 1
    else:
        # brute force: rotate in 90-degree increments until a match
        for angle in range(0, 361, 90):
            img = adjust_angle(img, angle)
            PIL_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            detected_word = OcrHandle().text_predict(PIL_image, short_size)
            result = searcch_provience(detected_word)
            # log the (possibly empty) result and the raw detail
            record_result(name, result)
            record_detail(name, detected_word[1])
            if result != "":
                break
    if result == "":
        error_img_save(name, origin_img)
        error_img_save(name + "mask", img)
    return result, detected_word


def split_resize_img(img, extra_box):
    """Crop the region described by *extra_box* and pad it with black
    borders to exactly image_size x image_size.

    param img: source image
    param extra_box: [min_x, max_x, min_y, max_y] crop bounds
    return: padded crop of size image_size x image_size
    """
    [min_x, max_x, min_y, max_y] = extra_box

    h = max_y - min_y
    w = max_x - min_x
    split_img = img[min_y:max_y, min_x:max_x]
    # Bug fix: integer halving on both sides lost one pixel for odd
    # differences (output was image_size-1); pad the remainder on the
    # bottom/right instead, and clamp negatives so an oversized crop
    # doesn't make cv2.copyMakeBorder fail.
    top = max((image_size - h) // 2, 0)
    bottom = max(image_size - h - top, 0)
    left = max((image_size - w) // 2, 0)
    right = max(image_size - w - left, 0)
    img = cv2.copyMakeBorder(split_img, top, bottom, left, right,
                             cv2.BORDER_CONSTANT, value=0)
    return img

def detect_prepare(image, mode=0):
    """Load or accept an image and run HSV-based pre-splitting.

    param image: file path (mode 0) or an already-loaded ndarray (otherwise)
    param mode: 0 = *image* is a path to read with cv2.imread
    return: (masked image, contour-centre list, boxes list, extra big boxes)
    """
    img = cv2.imread(image) if not mode else image
    # 1920x1080 originals are too large for good contour detection,
    # so shrink to 1080x720 first
    img = cv2.resize(img, (1080, 720))
    print("hsv split test")
    return hsv_split_pre(img, 1080, 720)

def Final_Match_Detect(img, extra_big_box, name):
    """National-contest recognition pipeline: split the image by the
    detected boxes (at most two), recognize each part and collect results.

    param img: preprocessed image (output of detect_prepare)
    param extra_big_box: list of [min_x, max_x, min_y, max_y] boxes
    param name: image name, used for logging and saved files
    return: list of recognized province names, one per processed box
    """
    global detect_count, image_count, result_image_save_path, error_image_save_path
    # sharpen the image before OCR
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], np.float32)
    img = cv2.filter2D(img, -1, kernel=kernel)

    detect_word_list = []
    # process at most the first two detected boxes
    for i, box in enumerate(extra_big_box[:2]):
        [min_x, max_x, min_y, max_y] = box
        box_right_x = max_x
        left_x = min_x if min_x >= 0 else 0  # clamp negative left bound
        h, w = img.shape[0], img.shape[1]
        img_copy = img.copy()
        img_copy[0:h, 0:left_x] = 0       # black out everything left of the box
        img_copy[0:h, box_right_x:w] = 0  # black out everything right of the box
        img_copy = split_resize_img(img_copy, extra_big_box[i])
        # Recognition on the isolated crop (the input is 1080x720 and is
        # compressed to image_size during recognition).
        time.sleep(1)
        # NOTE: the original wrote name+str(i).split('.')[0]; since str(i)
        # never contains '.', the effective sub-image name is name + str(i)
        result, detected_word = get_result2(img_copy, name + str(i), img_copy)
        # Bug fix: `!=` instead of identity comparison `is not` on a str
        # literal (implementation-dependent, SyntaxWarning on modern Python).
        if result != "":
            detect_count += 1
        else:
            # unmatched crops are labelled "安徽" by default
            result = "安徽"
        image_count += 1
        detect_word_list.append(result)
        print("detected:{}\r".format(result))
        CV_image = cv2AddChineseText(img_copy, result, (50, 100), (255, 0, 0), 50)

        # save the annotated result image
        name = name.split('.')[0]
        cv2.circle(CV_image, (box_right_x, int(h / 2)), 2, (0, 255, 0), 3)
        cv2.imwrite(result_image_save_path + "{}_del_{}.jpg".format(name, i), CV_image)
    print(detect_word_list)

    return detect_word_list


def prepare():
    """Clear artifacts of a previous run: both log files and every saved
    image in the result and error directories."""
    for log_file in (word_detect_path, detect_detail_path):
        if os.path.exists(log_file):
            os.remove(log_file)
    for folder in (result_image_save_path, error_image_save_path):
        if os.path.exists(folder):
            for entry in os.listdir(folder):
                os.remove(os.path.join(folder, entry))
    print("prepare done!")

if __name__ == "__main__":

    prepare()
    start_time = time.time()
    for i in os.listdir(source_image_path):
        if "jpg" in i:
            if "del" in i:
                continue  # skip annotated result images from a previous run
            img, cnt_centre_list, boxes_list, extra_big_box = detect_prepare(source_image_path + i)
            # Bug fix: Final_Match_Detect takes (img, extra_big_box, name)
            # and returns a single list; the old call passed five arguments
            # and unpacked two values, raising TypeError at runtime.
            detect_word_list = Final_Match_Detect(img, extra_big_box, i)
            detect_word_list = detect_word_list[:2]
            detect_word_list.reverse()
            print(detect_word_list)
            print(i + " done!")

    end_time = time.time()
    print("cost time:{} s".format(end_time - start_time))
    print("cost time:{} m".format((end_time - start_time) / 60))

    # guard against division by zero when no images were processed
    if image_count:
        print("sum:{} detected:{}, rate:{}".format(image_count, detect_count, detect_count / image_count))
    else:
        print("no images processed")